#include <cstdlib>
#include <fstream>
#include <iostream>
#include <sstream>
#include <stdexcept>
#include <string>
#include <vector>

#include <opencv2/opencv.hpp>
 
/**
 * Convert a per-class score blob into a colour-coded segmentation image.
 *
 * Each pixel is painted with the colour of its highest-scoring class
 * (arg-max over the channel axis; on ties the lower class index wins).
 *
 * @param score  4-D, continuous CV_32F blob laid out as 1 x C x H x W
 *               (e.g. 1 x 20 x 512 x 1024 for ENet). Only batch item 0 is
 *               read. The blob is NOT modified.
 * @param segm   Output: H x W CV_8UC3 image (allocated/reallocated here).
 * @param colors Colour table indexed by class id; needs at least C entries.
 * @throws std::invalid_argument if score does not have the expected layout,
 *         has more than 256 channels, or colors has fewer than C entries.
 */
void colorizeSegmentation(const cv::Mat &score, cv::Mat &segm, const std::vector<cv::Vec3b> &colors) {
    if (score.dims != 4 || score.type() != CV_32FC1 || !score.isContinuous()) {
        throw std::invalid_argument("colorizeSegmentation: score must be a continuous 1xCxHxW CV_32F blob");
    }
    const int chns = score.size[1];
    const int rows = score.size[2];
    const int cols = score.size[3];

    // Class ids are stored as uchar in maxCl, so at most 256 classes fit.
    if (chns < 1 || chns > 256) {
        throw std::invalid_argument("colorizeSegmentation: channel count must be in [1, 256]");
    }
    // Every class id that can win the arg-max must have a colour.
    if (colors.size() < static_cast<size_t>(chns)) {
        throw std::invalid_argument("colorizeSegmentation: fewer colors than classes");
    }

    // Arg-max over channels. maxCl starts at class 0. maxVal starts as a COPY
    // of channel 0 (the first rows*cols floats of the continuous blob) so the
    // running maximum never writes through the caller's const input.
    cv::Mat maxCl = cv::Mat::zeros(rows, cols, CV_8UC1);
    cv::Mat maxVal = cv::Mat(rows, cols, CV_32FC1, score.data).clone();
    for (int ch = 1; ch < chns; ch++)
    {
        for (int row = 0; row < rows; row++)
        {
            const float *ptrScore = score.ptr<float>(0, ch, row);
            uint8_t *ptrMaxCl = maxCl.ptr<uint8_t>(row);
            float *ptrMaxVal = maxVal.ptr<float>(row);
            for (int col = 0; col < cols; col++)
            {
                if (ptrScore[col] > ptrMaxVal[col])
                {
                    ptrMaxVal[col] = ptrScore[col];
                    ptrMaxCl[col] = (uchar)ch;
                }
            }
        }
    }

    // Map every class id to its colour, one row at a time.
    segm.create(rows, cols, CV_8UC3);
    for (int row = 0; row < rows; row++) {
        const uchar *ptrMaxCl = maxCl.ptr<uchar>(row);
        cv::Vec3b *ptrSegm = segm.ptr<cv::Vec3b>(row);
        for (int col = 0; col < cols; col++) {
            ptrSegm[col] = colors[ptrMaxCl[col]];
        }
    }
}
 
int main(int argc, char** argv) {
    // load the class label names
    std::vector<std::string> classes;
    std::string classesFile = "../enet-cityscapes/enet-classes.txt";
    std::ifstream classNamesFile(classesFile.c_str());
    if (classNamesFile.is_open()) {
        std::string className = "";
        while (std::getline(classNamesFile, className)) {
            classes.push_back(className);
        }
    }
    else {
        std::cout << "can not open class label file" <<std::endl;
        return EXIT_FAILURE;
    }
 
    // if a colors file was supplied, load it from disk
    std::vector<cv::Vec3b> colors;
    std::string colorsFile = "../enet-cityscapes/enet-colors.txt";
    std::ifstream colorNamesFile(colorsFile.c_str());
    if (colorNamesFile.is_open()) {
        std::string line;
        while (std::getline(colorNamesFile, line)) { // 解析不够优雅
            auto pos1 = line.find_first_of(',');
            auto pos2 = line.find_last_of(',');
            int b = std::stoi(line.substr(0, pos1));
            int g = std::stoi(line.substr(pos1+1, pos2-pos1-1));
            int r = std::stoi(line.substr(pos2+1, line.size()-pos2-1));
            colors.push_back(cv::Vec3b(b, g, r));
        }
    } else { // otherwise, we need to randomly generate RGB colors for each class label
	    // initialize a list of colors to represent each class label in the mask
	    // (starting with 'black' for the background/unlabeled regions)
        colors.push_back(cv::Vec3b());
        for (int i = 1; i < classes.size(); ++i) {
            cv::Vec3b color;
            for (int j = 0; j < 3; ++j) {
                color[j] = rand() % 256;
            }
            colors.push_back(color);
        }
    }
 
    // initialize the legend visualization
    cv::Mat legend(colors.size() * 25 + 25, 300, CV_8UC3);
 
    // loop over the class names + colors
    for (int i = 0; i < classes.size(); ++i) {
        // draw the class name + color on the legend
        cv::putText(legend, classes[i], cv::Point(5, (i * 25) + 17),
                    cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 0, 255), 2);
        cv::rectangle(legend, cv::Point(100, (i * 25)), cv::Point(300, (i * 25) + 25), colors[i], -1);
    }
    cv::imshow("legend", legend);
    cv::waitKey(1);
 
    // load our serialized model from disk
    std::cout << "loading model..." << std::endl;
    cv::dnn::Net net = cv::dnn::readNet("../enet-cityscapes/enet-model.net");
    //net.setPreferableTarget(cv::dnn::DNN_TARGET_OPENCL);
 
    cv::VideoCapture cap("../enet-cityscapes/toronto.mp4");
    cv::Mat image;
 
    while (cv::waitKey(1) != 27) // 'Esc'
    {
        // load the input image, resize it, and construct a blob from it,
        // but keeping mind mind that the original input image dimensions
        // ENet was trained on was 1024x512
        cap >> image;
        cv::Mat blob;
        cv::dnn::blobFromImage(image, blob, 1/255.0, cv::Size(1024, 512), cv::Scalar(0, 0, 0), true, false);
 
        // perform a forward pass using the segmentation model
        net.setInput(blob);
 
        cv::Mat output = net.forward();
 
        cv::Mat mask;
        colorizeSegmentation(output, mask, colors);
 
        // resize the mask such that its dimensions match the original size of the input image
        cv::resize(mask, mask, image.size(), 0, 0, cv::INTER_NEAREST);
 
        // perform a weighted combination of the input image with the mask to form an output visualization
        cv::addWeighted(image, 0.4, mask, 0.6, 0, output);
 
        // Put efficiency information.
        std::vector<double> layersTimes;
        double freq = cv::getTickFrequency() / 1000;
        double t = net.getPerfProfile(layersTimes) / freq;
        std::string label = cv::format("Inference time: %.2f ms", t);
        cv::putText(output, label, cv::Point(0, 15), cv::FONT_HERSHEY_SIMPLEX, 0.5, cv::Scalar(0, 255, 0));
 
        //show the input and output images
        cv::imshow("Input", image);
        cv::imshow("Output", output);
        cv::waitKey(0);
    }
    cap.release();
    return EXIT_SUCCESS;
}